import matplotlib.pyplot as plt
from bluepysnap import Circuit as SnapCircuit

import numpy as np
import bluepysnap
import matplotlib.patches as patches

import json

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from matplotlib.colors import LinearSegmentedColormap

# Colour stops of the membrane-voltage colormap, from the low end of the scale
# to the high end (RGB components in [0, 1]).
colors = [
    (0, 0, 0.1),  # near-black navy
    (0, 0, 1),  # blue
    (1, 0, 0),  # red
    (1, 1, 0),  # yellow
    (1, 1, 1),  # white
]
# 900-entry colormap built from the stops above; used for the soma-voltage heatmap.
vmcmap = LinearSegmentedColormap.from_list("VmColorMap", colors, N=900)
# Plot colour per cell type / region key (any matplotlib colour spec; "k" = black).
# NOTE(review): "#1f77b4" and "#2ca02c" are matplotlib's default tab10 C0/C2, but
# "#ff7f02" is one digit off tab10's C1 orange ("#ff7f0e") — confirm it is intended.
colors_per_mtype = {
    "Rt_RC": "#1f77b4",  # RT cells (blue)
    "VPL_TC": "#ff7f02",  # TC cells (orange)
    "VPL_IN": "#2ca02c",  # green
    "mc2_Column": "k",
    "mc2;VPL": "k",
    "mc2;Rt": "#1f77b4",  # same blue as Rt_RC
}


def get_data_sorted(sim, gids, ts, te, ax="y"):
    """Load soma voltages of ``gids`` with columns ordered by a node coordinate.

    Args:
        sim: simulation object exposing ``reports["soma_report"]["thalamus_neurons"]``
            and ``circuit.nodes["thalamus_neurons"]`` (e.g. a bluepysnap Simulation).
        gids: node ids to load from the soma report.
        ts, te: start / stop times passed to the report's ``get``.
        ax: node property used for ordering (default ``"y"``).

    Returns:
        The report window as a DataFrame whose columns are relabelled with each
        node's ``ax`` value and sorted in ascending order. Nodes sharing the same
        value are kept once each, in their original relative order.
    """
    soma_pop = sim.reports["soma_report"]["thalamus_neurons"]
    vdata = soma_pop.get(group={"node_id": gids}, t_start=ts, t_stop=te)

    # Look the node population up once, not once per column.
    nodes = sim.circuit.nodes["thalamus_neurons"]
    positions = np.array([nodes.get(gid, properties=[ax]).item() for gid in vdata.columns])

    # Reorder by position instead of by the relabelled coordinate: selecting
    # ``vdata[sorted(labels)]`` duplicates columns whenever two cells share a value.
    order = np.argsort(positions, kind="stable")
    vdata = vdata.iloc[:, order]
    vdata.columns = positions[order]
    return vdata


def _sample_gids(gids, sample):
    """Draw up to ``sample`` distinct ids from ``gids`` without replacement."""
    # np.random.choice(..., replace=False) raises when asked for more items than exist.
    return np.random.choice(gids, min(sample, len(gids)), replace=False)


def _plot_voltage_heatmap(fig, ax_v, simulation, gids, ts, te, sort_by, plot_up_times):
    """Draw the soma-voltage heatmap of ``gids`` on ``ax_v`` with a colorbar to its right."""
    data = get_data_sorted(simulation, gids, ts, te, ax=sort_by)
    data.index = data.index.astype(int)
    n_cells = len(data.T)

    minv, maxv = -90, -40
    heatmap = ax_v.imshow(
        data.T,
        interpolation="bilinear",
        aspect="auto",
        cmap=vmcmap,
        vmin=minv,
        vmax=maxv,
        extent=[ts, te, 0, n_cells],
    )

    ax_v.invert_yaxis()
    ax_v.set_ylabel("TC     RT", labelpad=-15)
    ax_v.set_yticks([1, n_cells])
    ax_v.set_yticklabels([n_cells, 1])
    ax_v.tick_params(axis="y", labelsize=8)
    # Ticks every 2000 ms, anchored 500 ms before the first Up state (or at ts if none).
    xtick_start = plot_up_times[0] - 500 if plot_up_times else ts
    ax_v.set_xticks(np.arange(start=xtick_start, stop=te, step=2000))
    # Assigning ``ax.labelbottom`` is ignored by matplotlib; tick_params is the real switch.
    ax_v.tick_params(labelbottom=True)
    ax_v.set_xlabel("Time (ms)")

    # Colorbar in an inset just outside the right edge of the heatmap
    # (x-offset 1.05 in axes coordinates), 2% of the axes wide and 75% tall.
    cb_ax = inset_axes(
        ax_v,
        width="2%",
        height="75%",
        loc="lower left",
        bbox_to_anchor=(1.05, 0.0, 1, 1),
        bbox_transform=ax_v.transAxes,
        borderpad=0,
    )
    cbar = fig.colorbar(heatmap, orientation="vertical", cax=cb_ax)
    cbar.set_label("mV", rotation=90, labelpad=15)


def _color_psth_axis(axis, mtype, label, ylim, spine):
    """Set y-limits and label of a PSTH axis and colour its ticks/``spine`` like ``mtype``."""
    color = colors_per_mtype[mtype]
    axis.set_ylim(ylim)
    axis.set_ylabel(label, color=color)
    axis.tick_params(axis="y", colors=color)
    axis.spines["top"].set_visible(False)
    axis.spines[spine].set_color(color)


def _annotate_rt_peaks(zoom_ax, simulation, time_binsize=5, window=(9400, 10400)):
    """Mark peaks of the Rt_RC population firing rate inside ``window`` on ``zoom_ax``.

    The rate is recomputed from the Rt_RC spike report (not restricted to the
    analysed time window): spike counts per ``time_binsize`` ms bin, divided by
    the number of distinct (id, population) pairs present in the report and
    converted to Hz. Peaks come from ``scipy.signal.find_peaks(height=2, width=3)``.
    They are drawn as red dots, and the first and last peak also get dashed
    magenta lines.
    """
    from scipy.signal import find_peaks

    report = (
        simulation.spikes["thalamus_neurons"].spike_report.filter(group={"mtype": "Rt_RC"}).report
    )
    if report.empty:
        return

    times = report.index
    time_start = np.min(times)
    time_stop = np.max(times)

    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(times, bins=bins)
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)

    in_window = (bin_edges[:-1] >= window[0]) & (bin_edges[:-1] <= window[1])
    f = freq[in_window]
    t = bin_edges[:-1][in_window]
    peaks = find_peaks(f, height=2, width=3)[0]
    zoom_ax.scatter(t[peaks], f[peaks], c="r")
    if len(peaks):  # peaks[0] would raise IndexError when nothing was detected
        zoom_ax.axvline(t[peaks[0]], color="m", linestyle="dashed")
        zoom_ax.axvline(t[peaks[-1]], color="m", linestyle="dashed")


def _plot_spike_raster(ax_r, simulation, mtypes, gids_per_mtypes_dict, sample, ts, te, plot_up_times):
    """Scatter spikes of up to ``sample`` random cells per mtype on ``ax_r`` and shade Up states."""
    spiking_ids = []
    for mtype in mtypes:
        gids = _sample_gids(gids_per_mtypes_dict[mtype], sample)
        spikes_df = simulation.spikes.filter(group=gids, t_start=ts, t_stop=te)
        if spikes_df.report.empty:
            continue
        spiking_ids.extend(spikes_df.report.ids)
        ax_r.scatter(
            spikes_df.report.index,
            spikes_df.report.ids,
            marker=".",
            s=0.8,
            alpha=0.8,
            linewidths=0,
            color=colors_per_mtype[mtype],
            rasterized=True,
        )

    ax_r.spines.top.set_visible(False)
    ax_r.spines.right.set_visible(False)
    ax_r.set(yticks=[])
    ax_r.set_ylabel("TC     RT")
    ax_r.tick_params(axis="y", labelsize=8)

    if not spiking_ids:
        # np.min/np.max of an empty list would raise; keep matplotlib's default limits.
        return

    # Round the id range outward to whole thousands; limits are inverted (ids grow downwards).
    min_val = np.floor(np.min(spiking_ids) / 1000) * 1000
    max_val = np.ceil(np.max(spiking_ids) / 1000) * 1000
    ax_r.set_ylim([max_val, min_val])

    # Up-state shading, 500 ms wide, spanning the full id range.
    for start in plot_up_times:
        points = [(start, max_val), (start, min_val), (start + 500, min_val), (start + 500, max_val)]
        rect = patches.Polygon(points, linewidth=1, edgecolor="none", facecolor="grey", alpha=0.1)
        ax_r.add_patch(rect)


def plot_analysis_vertical(
    simulation, ts, te, gids_per_mtypes_dict, fig_prefix="", sort_by="y", plot_up_times=None
):
    """Plot a three-row activity summary of a thalamic simulation, plus a zoomed PSTH.

    Rows of the main figure (shared time axis, ms):
      1. spike raster of up to 500 randomly drawn VPL_TC and Rt_RC cells;
      2. PSTH (5 ms bins) of Rt_RC on the left axis and VPL_TC on a twin right axis;
      3. soma-voltage heatmap of up to 500 cells sampled from each mtype that
         spiked in [ts, te].
    A second figure shows both PSTHs between 9400 and 10400 ms with Rt_RC rate peaks.
    The figures are saved as ``f"{fig_prefix}voltage_raster.pdf"`` and
    ``f"{fig_prefix}freq_zoom.pdf"``.

    Args:
        simulation: simulation object (e.g. bluepysnap Simulation) with a
            ``"soma_report"`` report and spikes for ``"thalamus_neurons"``.
        ts, te: start / end of the analysed window (ms).
        gids_per_mtypes_dict: mapping ``"VPL_TC"`` / ``"Rt_RC"`` -> candidate node ids.
        fig_prefix: string prepended to the output file names (default: none).
        sort_by: node property passed to ``get_data_sorted`` to order heatmap rows.
        plot_up_times: start times of the 500 ms Up states to shade; the first one
            also anchors the x ticks. ``None`` or empty disables both.
    """
    # None instead of a shared mutable [] default; an empty sequence is still accepted.
    plot_up_times = [] if plot_up_times is None else list(plot_up_times)

    mtypes = ["VPL_TC", "Rt_RC"]
    sample = 500

    # Sample cells from every mtype that spiked in [ts, te]; they feed the voltage heatmap.
    gids_all = []
    for mtype in mtypes:
        gids = _sample_gids(gids_per_mtypes_dict[mtype], sample)
        spikes_df = simulation.spikes.filter(group={"mtype": mtype}, t_start=ts, t_stop=te)
        if not spikes_df.report.empty:
            gids_all.extend(gids)

    fig, (ax_1, ax_2, ax_3) = plt.subplots(nrows=3, figsize=(10, 4), sharex=True)

    # third row - soma voltage heatmap of the cells collected above. Previously
    # the leftover loop variable was used here, so only the last mtype was shown.
    if gids_all:
        _plot_voltage_heatmap(fig, ax_3, simulation, gids_all, ts, te, sort_by, plot_up_times)

    # second row - PSTH: RT on ax_2, TC on a twin right-hand axis
    bs = 5
    psth_limit = {"Rt_RC": [0, 100], "VPL_TC": [0, 30]}

    rt_spikes = simulation.spikes.filter(group={"mtype": "Rt_RC"}, t_start=ts, t_stop=te)
    rt_spikes.firing_rate_histogram(time_binsize=bs, ax=ax_2)
    _color_psth_axis(ax_2, "Rt_RC", "RT (Hz)", psth_limit["Rt_RC"], "left")

    zoom_fig, zoom_ax = plt.subplots(figsize=(4, 2))
    rt_spikes.firing_rate_histogram(time_binsize=bs, ax=zoom_ax)

    tc_spikes = simulation.spikes.filter(group={"mtype": "VPL_TC"}, t_start=ts, t_stop=te)
    tc_ax = ax_2.twinx()
    # Draw TC on the twin so its [0, 30] Hz scale really applies to the TC curve
    # (it used to be drawn on ax_2's RT scale). A fresh axis restarts the colour
    # cycle, so pin the TC colour explicitly.
    tc_ax.set_prop_cycle(color=[colors_per_mtype["VPL_TC"]])
    tc_spikes.firing_rate_histogram(time_binsize=bs, ax=tc_ax)
    _color_psth_axis(tc_ax, "VPL_TC", "TC (Hz)", psth_limit["VPL_TC"], "right")
    tc_spikes.firing_rate_histogram(time_binsize=bs, ax=zoom_ax)

    # Zoomed PSTH figure with Rt_RC rate peaks marked
    _annotate_rt_peaks(zoom_ax, simulation, time_binsize=5, window=(9400, 10400))
    zoom_ax.set_ylabel("RT/TC(Hz)")
    zoom_ax.axvspan(9500, 10000, color="grey", alpha=0.2)
    zoom_ax.set_xlabel("Time (ms)")
    zoom_ax.set_xlim(9400, 10400)
    zoom_ax.set_ylim(0, 120)
    zoom_fig.tight_layout()
    zoom_fig.savefig(f"{fig_prefix}freq_zoom.pdf")

    # Up-state shading on the PSTH row
    for start in plot_up_times:
        ax_2.axvspan(start, start + 500, color="grey", alpha=0.2)

    # first row - spike raster
    _plot_spike_raster(ax_1, simulation, mtypes, gids_per_mtypes_dict, sample, ts, te, plot_up_times)

    # plt.tight_layout() would act on pyplot's *current* figure, which is zoom_fig here.
    fig.tight_layout()
    fig.savefig(f"{fig_prefix}voltage_raster.pdf")


if __name__ == "__main__":
    circuit_path = "/gpfs/bbp.cscs.ch/project/proj55/iavarone/releases/circuits/O1/2019-11-19_sonata_Zenodo/circuit_sonata.json"
    thal_circuit = SnapCircuit(circuit_path)

    mtypes = ["Rt_RC", "VPL_TC"]
    mtype_regions = ["mc2;Rt", "mc2;VPL"]

    # Candidate node ids per mtype, restricted to the paired region.
    node_population = thal_circuit.nodes["thalamus_neurons"]
    gids_per_mtypes_dict = {
        mtype: node_population.ids({"region": mtype_region, "mtype": mtype})
        for mtype, mtype_region in zip(mtypes, mtype_regions)
    }

    sim_root = "../../simulations/08_08_Cav_variants/"

    ts, te = 1500, 20000

    sim_type_list = ["variant7_mixed50pEcel50pSppRunaway"]
    sims_list = ["_100percent_CT_up_down"]

    # Up-state onsets (ms): every 1000 ms from 3500 to 18500 inclusive.
    up_start = list(range(3500, 18501, 1000))

    for sim_type_prefix in sim_type_list:
        for sim_folder in sims_list:
            sim_folder_full_name = sim_type_prefix + "_SUPER_LONG" + sim_folder

            sim_path = (
                sim_root + sim_type_prefix + "/" + sim_folder_full_name + "/simulation_config.json"
            )
            # Close the config file deterministically (json.load(open(...)) leaked it).
            with open(sim_path, encoding="utf-8") as config_file:
                data = json.load(config_file)

            circuit_var = data["network"].split("/")[-1]
            percent_CT_input = data["inputs"]["spikeReplay_ct_UP1"]["spike_file"].split("_")[-4]

            print(f"\nAnalysing sim for circuit {circuit_var}, percent_CT_input {percent_CT_input}")

            _sim = bluepysnap.Simulation(sim_path)
            plot_analysis_vertical(
                _sim,
                ts,
                te,
                gids_per_mtypes_dict,
                fig_prefix="",
                sort_by="y",
                plot_up_times=up_start,
            )
